{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepping MobileCLIP model for use in Ente\n",
    "\n",
    "[Paper](https://arxiv.org/pdf/2311.17049.pdf) | [GitHub](https://github.com/apple/ml-mobileclip)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up PyTorch weights and source code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !mkdir mobileclip_repo\n",
    "# %cd mobileclip_repo\n",
    "# !git clone https://github.com/apple/ml-mobileclip.git\n",
    "# %cd ml-mobileclip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd mobileclip_repo/ml-mobileclip/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !source get_pretrained_models.sh   # Files will be downloaded to `checkpoints` directory.\n",
    "# %cd ../.."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quote each requirement: an unquoted `>=` is parsed by the shell as an output redirect,\n",
    "# which drops the version constraint and creates stray files such as `=1.4.0`.\n",
    "!uv pip install \"clip-benchmark>=1.4.0\" \"datasets>=2.8.0\" \"open-clip-torch>=2.20.0\" \"timm>=0.9.5\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.onnx\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "import mobileclip\n",
    "import numpy as np\n",
    "from numpy.linalg import norm\n",
    "import onnx\n",
    "import onnxruntime as ort\n",
    "print(ort.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build MobileCLIP-S2 together with its preprocessing transform from the local checkpoint.\n",
    "model, _, preprocess = mobileclip.create_model_and_transforms(\n",
    "    'mobileclip_s2',\n",
    "    pretrained='checkpoints/mobileclip_s2.pt',\n",
    ")\n",
    "og_model = model  # second name bound to the same module object\n",
    "model.eval()\n",
    "og_model.eval()\n",
    "tokenizer = mobileclip.get_tokenizer('mobileclip_s2')\n",
    "\n",
    "# Smoke test: score one image from the repository against four candidate captions.\n",
    "image = preprocess(Image.open(\"docs/fig_accuracy_latency.png\").convert('RGB')).unsqueeze(0)\n",
    "text = tokenizer([\"Hello World!\", \"a diagram\", \"a dog\", \"a cat\"])\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast():\n",
    "    image_features = model.encode_image(image)\n",
    "    text_features = model.encode_text(text)\n",
    "    # Normalise in place so the matrix product below is a cosine similarity.\n",
    "    image_features /= image_features.norm(dim=-1, keepdim=True)\n",
    "    text_features /= text_features.norm(dim=-1, keepdim=True)\n",
    "\n",
    "    scaled_similarity = 100.0 * image_features @ text_features.T\n",
    "    text_probs = scaled_similarity.softmax(dim=-1)\n",
    "\n",
    "print(\"Label probs:\", text_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ../.."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !rm -rf mobileclip_repo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer([\"This is a tokenized string\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deliberately long prompt so the tokenized input is filled with real token ids.\n",
    "text_input = tokenizer([\"Hello World! This is a super duper long piece of text of at least 77 tokens, purely to make sure that indeed this is a good input without any zeros that the exporter might somehow confuse with a boolean. Apparently we're still not at 77 tokens, so I just keep on monkey typing this story in the hope that someday I have a fully tokenized string of text that is longer than the required 77 tokens. Thank you for coming to my TED talk.\"])\n",
    "# Reference PyTorch embedding, used later to validate the ONNX text models.\n",
    "# Inference only, so skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    text_emb = model.encode_text(text_input)[0].detach().numpy()\n",
    "text_emb /= norm(text_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_singapore = Image.open(\"../data/singapore.jpg\").convert('RGBA')\n",
    "image_input = preprocess(image_singapore).unsqueeze(0)\n",
    "print(image_input.detach().numpy().shape)\n",
    "print(1*3*256*256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_emb = model(image_input[:,:3,:,:])[0][0].detach().numpy()\n",
    "print(image_emb.shape)\n",
    "print(norm(image_emb))\n",
    "image_emb[0:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_singapore_onnx = np.array(image_singapore)\n",
    "print(image_singapore_onnx.shape)\n",
    "print(image_singapore_onnx.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_opset = 18  # use opset 18 for Resize to antialias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncodeImageWrapper(nn.Module):\n",
    "    \"\"\"Module whose forward pass is the wrapped model's `encode_image`.\"\"\"\n",
    "\n",
    "    def __init__(self, original_model):\n",
    "        super().__init__()\n",
    "        self.original_model = original_model\n",
    "\n",
    "    def forward(self, input):\n",
    "        image_embedding = self.original_model.encode_image(input)\n",
    "        return image_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "image_model_wrapper = EncodeImageWrapper(model)\n",
    "image_model_wrapper.eval()\n",
    "image_model_wrapper.original_model.eval()\n",
    "clip_image_onnx_export_path = \"onnx_models/mobileclip_s2_image_float32.onnx\"\n",
    "# Create the output directory up front so the export does not fail on a fresh checkout.\n",
    "os.makedirs(os.path.dirname(clip_image_onnx_export_path), exist_ok=True)\n",
    "# `image` is the preprocessed sample tensor created when the model was loaded; it is the example input for the export.\n",
    "torch.onnx.export(\n",
    "    image_model_wrapper,\n",
    "    image,\n",
    "    clip_image_onnx_export_path,\n",
    "    opset_version=onnx_opset,\n",
    "    do_constant_folding=True,\n",
    "    input_names=[\"input\"],\n",
    "    output_names=[\"output\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "mobileclip_image_onnx = onnx.load(clip_image_onnx_export_path)\n",
    "onnx.checker.check_model(mobileclip_image_onnx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncodeTextWrapper(nn.Module):\n",
    "    \"\"\"Module whose forward pass is the wrapped model's `encode_text`.\"\"\"\n",
    "\n",
    "    def __init__(self, original_model):\n",
    "        super().__init__()\n",
    "        self.original_model = original_model\n",
    "\n",
    "    def forward(self, input):\n",
    "        text_embedding = self.original_model.encode_text(input)\n",
    "        return text_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model so that `forward` is the text encoder, and put it in eval mode.\n",
    "text_model_wrapper = EncodeTextWrapper(model)\n",
    "text_model_wrapper.eval()\n",
    "text_model_wrapper.original_model.eval()\n",
    "\n",
    "clip_text_onnx_export_path = \"onnx_models/mobileclip_s2_text_int64.onnx\"\n",
    "# `text_input` (the long tokenized prompt) is the example input for the export.\n",
    "torch.onnx.export(\n",
    "    text_model_wrapper,\n",
    "    text_input,\n",
    "    clip_text_onnx_export_path,\n",
    "    opset_version=onnx_opset,\n",
    "    do_constant_folding=True,\n",
    "    input_names=['input'],\n",
    "    output_names=['output'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Altering ONNX models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Change the input name to `og_input` so we can reserve `input` for the altered model that includes preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewire every node that reads the exported graph input so it consumes \"og_input\" instead.\n",
    "for node in mobileclip_image_onnx.graph.node:\n",
    "    for idx in range(len(node.input)):\n",
    "        if node.input[idx] == \"input\":\n",
    "            node.input[idx] = \"og_input\"\n",
    "\n",
    "# Declaration of the renamed graph input: a float tensor of shape [1, 3, 256, 256].\n",
    "og_input = onnx.helper.make_tensor_value_info(\n",
    "    name=\"og_input\",\n",
    "    elem_type=onnx.TensorProto.FLOAT,\n",
    "    shape=[1, 3, 256, 256],\n",
    ")\n",
    "\n",
    "# Rebuild the graph around the renamed input, keeping nodes, outputs and initializers.\n",
    "graph = onnx.helper.make_graph(\n",
    "    nodes=mobileclip_image_onnx.graph.node,\n",
    "    name=mobileclip_image_onnx.graph.name,\n",
    "    inputs=[og_input],\n",
    "    outputs=mobileclip_image_onnx.graph.output,\n",
    "    initializer=mobileclip_image_onnx.graph.initializer,\n",
    ")\n",
    "export_opset = onnx.helper.make_opsetid(\"\", onnx_opset)\n",
    "mobileclip_image_onnx = onnx.helper.make_model(graph, opset_imports=[export_opset])\n",
    "onnx.save_model(mobileclip_image_onnx, clip_image_onnx_export_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add preprocessing to the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor, create_named_value, Resize, ImageBytesToFloat, Unsqueeze, CenterCrop, Debug, ChannelsLastToChannelsFirst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input of the preprocessing pipeline: a uint8 tensor with symbolic dims H, W, C.\n",
    "inputs = [create_named_value(\"input_to_process\", onnx.TensorProto.UINT8, [\"H\", \"W\", \"C\"])]\n",
    "pipeline = PrePostProcessor(inputs, onnx_opset)\n",
    "\n",
    "# Debug() steps can be inserted between the steps below while troubleshooting.\n",
    "preprocessing_steps = [\n",
    "    Resize(256),\n",
    "    CenterCrop(256, 256),  # Crop to 256x256. NOTE: Currently only HWC input is handled.\n",
    "    ChannelsLastToChannelsFirst(),  # HWC --> CHW\n",
    "    ImageBytesToFloat(),  # uint8 --> float in range 0..1 (divide by 255)\n",
    "    Unsqueeze([0]),  # add batch axis, CHW --> 1CHW\n",
    "]\n",
    "pipeline.add_pre_processing(preprocessing_steps)\n",
    "\n",
    "# Prepend the preprocessing steps to the exported image encoder and validate the result.\n",
    "clip_image_with_preprocessing = pipeline.run(mobileclip_image_onnx)\n",
    "onnx.checker.check_model(clip_image_with_preprocessing)\n",
    "\n",
    "clip_image_onnx_rgb_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgb.onnx\"\n",
    "new_model_path = clip_image_onnx_rgb_path\n",
    "onnx.save_model(clip_image_with_preprocessing, new_model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add a slice node so that the model can take raw RGBA data as input (as well as standard RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = clip_image_with_preprocessing\n",
    "\n",
    "\n",
    "def _int64_vector(name, values):\n",
    "    \"\"\"Build a 1-D INT64 tensor called `name` holding `values`.\"\"\"\n",
    "    return onnx.helper.make_tensor(\n",
    "        name=name,\n",
    "        data_type=onnx.TensorProto.INT64,\n",
    "        dims=[len(values)],\n",
    "        vals=np.array(values, dtype=np.int64),\n",
    "    )\n",
    "\n",
    "\n",
    "# Slice parameters: keep indices [0, 3) along axis 2, the channel axis of the HWC input.\n",
    "starts_tensor = _int64_vector(\"starts\", [0])\n",
    "ends_tensor = _int64_vector(\"ends\", [3])\n",
    "axes_tensor = _int64_vector(\"axes\", [2])\n",
    "new_initializers = [starts_tensor, ends_tensor, axes_tensor, *onnx_model.graph.initializer]\n",
    "\n",
    "# The Slice node reads the new public input and emits the first three channels.\n",
    "slice_node = onnx.helper.make_node(\n",
    "    \"Slice\",\n",
    "    inputs=[\"input\", \"starts\", \"ends\", \"axes\"],\n",
    "    outputs=[\"sliced_input\"],\n",
    "    name=\"slice_rgba_input_node\",\n",
    ")\n",
    "\n",
    "# New public graph input with a flexible channel dimension.\n",
    "new_input = onnx.helper.make_tensor_value_info(\n",
    "    name=\"input\",\n",
    "    elem_type=onnx.TensorProto.UINT8,\n",
    "    shape=[\"H\", \"W\", \"C\"],\n",
    ")\n",
    "\n",
    "# Assemble the graph with the Slice node placed ahead of the existing nodes.\n",
    "graph = onnx.helper.make_graph(\n",
    "    nodes=[slice_node, *onnx_model.graph.node],\n",
    "    name=onnx_model.graph.name,\n",
    "    inputs=[new_input],\n",
    "    outputs=list(onnx_model.graph.output),\n",
    "    initializer=new_initializers,\n",
    "    value_info=onnx_model.graph.value_info,\n",
    ")\n",
    "mobileclip_image_onnx_rgba = onnx.helper.make_model(\n",
    "    graph,\n",
    "    opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)],\n",
    ")\n",
    "\n",
    "# Point the consumers of the old pipeline input at the Slice output instead.\n",
    "for node in mobileclip_image_onnx_rgba.graph.node:\n",
    "    for idx in range(len(node.input)):\n",
    "        if node.input[idx] == \"input_to_process\":\n",
    "            node.input[idx] = \"sliced_input\"\n",
    "\n",
    "# Validate and save the new model\n",
    "onnx.checker.check_model(mobileclip_image_onnx_rgba)\n",
    "clip_image_onnx_rgba_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba.onnx\"\n",
    "onnx.save(mobileclip_image_onnx_rgba, clip_image_onnx_rgba_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Simplify](https://github.com/daquexian/onnx-simplifier) the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_image_sim_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba_sim.onnx\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!onnxsim {clip_image_onnx_rgba_path} {clip_image_sim_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Apply basic offline graph optimizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_opt_sess_options = ort.SessionOptions()\n",
    "\n",
    "# Only the basic optimization level is applied offline.\n",
    "image_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC\n",
    "\n",
    "clip_image_opt_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba_opt.onnx\"\n",
    "image_opt_sess_options.optimized_model_filepath = clip_image_opt_path\n",
    "\n",
    "# Creating the session writes the optimized model to `optimized_model_filepath`.\n",
    "opt_image_session = ort.InferenceSession(clip_image_sim_path, image_opt_sess_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add metadata to the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_image_opt = onnx.load(clip_image_opt_path)\n",
    "clip_image_opt.producer_name = \"EnteMobileCLIPImageEncoder\"\n",
    "clip_image_opt.doc_string = \"MobileCLIP S2 Image Encoder with built-in preprocessing. Accepts both RGB and RGBA raw bytes input (uint8) in HWC format.\"\n",
    "clip_image_opt.graph.doc_string = \"\"\n",
    "clip_image_opt.graph.name = \"SliceRGB+Resize+CenterCrop+ToFloat+Unsqueeze+MobileCLIP_S2_ImageEncoder\"\n",
    "onnx.save(clip_image_opt, clip_image_opt_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ort_session = ort.InferenceSession(clip_image_opt_path)\n",
    "onnx_emb = ort_session.run(None, {\"input\": image_singapore_onnx})[0][0]\n",
    "onnx_emb /= norm(onnx_emb)\n",
    "np.dot(image_emb, onnx_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm {clip_image_onnx_export_path}\n",
    "!rm {clip_image_onnx_rgb_path}\n",
    "!rm {clip_image_onnx_rgba_path}\n",
    "!rm {clip_image_sim_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make sure the model can use int32 as input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "mobileclip_text_onxx = onnx.load(clip_text_onnx_export_path)\n",
    "\n",
    "# Re-declare the graph input named \"input\" as INT32; only the first match is changed.\n",
    "text_graph_input = next(\n",
    "    (tensor for tensor in mobileclip_text_onxx.graph.input if tensor.name == \"input\"),\n",
    "    None,\n",
    ")\n",
    "if text_graph_input is not None:\n",
    "    text_graph_input.type.tensor_type.elem_type = onnx.TensorProto.INT32\n",
    "\n",
    "# Save the modified model\n",
    "clip_text_onnx_int32_path = \"onnx_models/mobileclip_s2_text_int32.onnx\"\n",
    "onnx.save(mobileclip_text_onxx, clip_text_onnx_int32_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Simplify](https://github.com/daquexian/onnx-simplifier) the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_text_sim_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int32_sim.onnx\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!onnxsim {clip_text_onnx_int32_path} {clip_text_sim_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Apply basic offline [graph optimizations](https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). Only do the basic optimizations offline, the extended and layout optimizations should be done online depending on execution provider and hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_opt_sess_options = ort.SessionOptions()\n",
    "\n",
    "# Only the basic optimization level is applied offline.\n",
    "text_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC\n",
    "\n",
    "clip_text_opt_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int32_opt.onnx\"\n",
    "text_opt_sess_options.optimized_model_filepath = clip_text_opt_path\n",
    "\n",
    "# Creating the session writes the optimized model to `optimized_model_filepath`.\n",
    "opt_text_session = ort.InferenceSession(clip_text_sim_path, text_opt_sess_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add metadata to the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_text_opt = onnx.load(clip_text_opt_path)\n",
    "clip_text_opt.producer_name = \"EnteMobileCLIPTextEncoder\"\n",
    "clip_text_opt.doc_string = \"MobileCLIP S2 Text Encoder. Accepts an integer array (int32) of length 77. Longer arrays will be truncated.\"\n",
    "clip_text_opt.graph.doc_string = \"\"\n",
    "clip_text_opt.graph.name = \"MobileCLIP_S2_TextEncoder\"\n",
    "onnx.save(clip_text_opt, clip_text_opt_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mobileclip_text_ort_sess = ort.InferenceSession(clip_text_opt_path)\n",
    "text_onnx_emb = mobileclip_text_ort_sess.run([\"output\"], {\"input\": text_input.numpy().astype(\"int32\")})[0][0]\n",
    "text_onnx_emb /= norm(text_onnx_emb)\n",
    "np.dot(text_emb, text_onnx_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm {clip_text_onnx_export_path}\n",
    "!rm {clip_text_onnx_int32_path}\n",
    "!rm {clip_text_sim_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantize text model\n",
    "\n",
    "https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quantization pre-processing (not to be confused with normal pre-processing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxruntime.quantization import quant_pre_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_text_quantized_preprocessed_path = \"onnx_models/mobileclip_s2_text_quant_preprocessed.onnx\"\n",
    "quant_pre_process(clip_text_opt_path, clip_text_quantized_preprocessed_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dynamic quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of all nodes of the optimized text model, in graph order.\n",
    "node_names = [node.name for node in clip_text_opt.graph.node]\n",
    "\n",
    "# Names of all MatMul nodes except the one singled out by name below.\n",
    "matmul_nodes_names = [\n",
    "    node.name\n",
    "    for node in clip_text_opt.graph.node\n",
    "    if node.op_type == \"MatMul\"\n",
    "    and node.name != \"/text_encoder/transformer.0/pre_norm_ffn/pre_norm_ffn.4/MatMul\"\n",
    "]\n",
    "len(node_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_text_quantized_dynamic_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_quant.onnx\"\n",
    "# `nodes_to_exclude` takes a list of node names. Passing a bare string would turn the\n",
    "# membership test into a substring match, so wrap the single name in a list.\n",
    "quantize_dynamic(\n",
    "    clip_text_quantized_preprocessed_path,\n",
    "    clip_text_quantized_dynamic_path,\n",
    "    nodes_to_exclude=[node_names[28]],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
    "text_onnx_quant_dyn_emb = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input.numpy().astype(\"int32\")})[0][0]\n",
    "text_onnx_quant_dyn_emb /= norm(text_onnx_quant_dyn_emb)\n",
    "np.dot(text_onnx_quant_dyn_emb, text_onnx_emb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quantization Debugging (uncomment if you want to try it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exclude_amount = 1\n",
    "\n",
    "\n",
    "# for i in range(25, 30, exclude_amount):\n",
    "#     begin = i\n",
    "#     end = min(i+exclude_amount, len(node_names))\n",
    "    \n",
    "#     clip_text_quantized_dynamic_debug_path = f\"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int8dyn_opt_debug.onnx\"\n",
    "#     quantize_dynamic(clip_text_quantized_preprocessed_path, clip_text_quantized_dynamic_debug_path, nodes_to_exclude=node_names[begin:end])\n",
    "#     mobileclip_text_quant_dyn_ort_sess_debug = ort.InferenceSession(clip_text_quantized_dynamic_debug_path)\n",
    "#     text_onnx_quant_dyn_emb_debug = mobileclip_text_quant_dyn_ort_sess_debug.run([\"output\"], {\"input\": text_input.numpy().astype(\"int32\")})[0][0]\n",
    "#     text_onnx_quant_dyn_emb_debug /= norm(text_onnx_quant_dyn_emb_debug)\n",
    "#     sim_debug = np.dot(text_onnx_quant_dyn_emb_debug, text_onnx_emb)\n",
    "#     print(f\"Skipping nodes from {begin} to {end} resulted in a similarity of {sim_debug:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_names[28:29]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test on a dataset of image captions. Before continuing, download the dataset from [Kaggle](https://www.kaggle.com/datasets/aladdinpersson/flickr8kimagescaptions/data) and put it in the `../data` folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import copy\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the caption (second column) of every row after the header line.\n",
    "with open('../data/flickr8k_captions.txt', 'r', encoding='utf-8') as file:\n",
    "    csv_reader = csv.reader(file)\n",
    "    next(csv_reader)  # skip the first row\n",
    "    captions = [row[1] for row in csv_reader]\n",
    "\n",
    "print(len(captions))\n",
    "print(captions[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test accuracy of quantized model quickly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_size = 600\n",
    "test_captions = captions[:test_size]\n",
    "# Size the result by the captions actually available, so unused zero entries\n",
    "# cannot skew the statistics when the dataset has fewer than `test_size` rows.\n",
    "similarities = np.zeros(len(test_captions))\n",
    "mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
    "\n",
    "# Wrap the list (not the enumerate object) so tqdm knows the total and shows real progress.\n",
    "for i, caption in enumerate(tqdm(test_captions)):\n",
    "    text_input_test = tokenizer([caption])\n",
    "    # Reference embedding from the PyTorch model; inference only, so no autograd graph.\n",
    "    with torch.no_grad():\n",
    "        text_emb_test = model.encode_text(text_input_test)[0].detach().numpy()\n",
    "    text_emb_test /= norm(text_emb_test)\n",
    "    # Embedding from the quantized ONNX model for the same tokens, fed as int32.\n",
    "    text_onnx_test_emb = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input_test.numpy().astype(\"int32\")})[0][0]\n",
    "    text_onnx_test_emb /= norm(text_onnx_test_emb)\n",
    "    # Both vectors are unit length, so the dot product is their cosine similarity.\n",
    "    similarities[i] = np.dot(text_onnx_test_emb, text_emb_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Mean similarity: {similarities.mean()}\")\n",
    "print(f\"Standard deviation: {similarities.std()}\")\n",
    "print(f\"Minimum similarity: {similarities.min()}\")\n",
    "print(f\"Maximum similarity: {similarities.max()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test accuracy of quantized model extensively (uncomment code below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# captions_extensive = copy.deepcopy(captions)\n",
    "\n",
    "# for i in range(10000):\n",
    "#     captions_extensive[i] = captions_extensive[i] + \" \" + captions_extensive[i + 10000] + \" \" + captions_extensive[i + 20000] + \" \" + captions_extensive[i + 30000]\n",
    "#     captions_extensive[i + 10000] = captions_extensive[i + 10000] + \" \" + captions_extensive[i + 20000] + \" \" + captions_extensive[i + 30000]\n",
    "#     captions_extensive[i + 20000] = captions_extensive[i + 20000] + \" \" + captions_extensive[i + 30000]\n",
    "# captions_extensive = captions_extensive[:40000]\n",
    "\n",
    "# test_size = len(captions_extensive)\n",
    "# similarities_extensive = np.zeros(test_size)\n",
    "# mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
    "\n",
    "# for i, caption in tqdm(enumerate(captions_extensive[:test_size])):\n",
    "#     text_input_test = tokenizer([caption])\n",
    "#     text_emb_test = model.encode_text(text_input_test)[0].detach().numpy()\n",
    "#     text_emb_test /= norm(text_emb_test)\n",
    "#     text_onnx_test_emb = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input_test.numpy().astype(\"int32\")})[0][0]\n",
    "#     text_onnx_test_emb /= norm(text_onnx_test_emb)\n",
    "#     similarities_extensive[i] = np.dot(text_onnx_test_emb, text_emb_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(f\"Mean similarity: {similarities_extensive.mean()}\")\n",
    "# print(f\"Standard deviation: {similarities_extensive.std()}\")\n",
    "# print(f\"Minimum similarity: {similarities_extensive.min()}\")\n",
    "# print(f\"Maximum similarity: {similarities_extensive.max()}\")\n",
    "# print(f\"Percentage of similarities above 0.99: {np.sum(similarities_extensive > 0.99) / len(similarities_extensive) * 100:.2f}%\")\n",
    "# print(f\"Percentage of similarities above 0.995: {np.sum(similarities_extensive > 0.995) / len(similarities_extensive) * 100:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Investigating the MatMul excluded from quantization to improve performance (uncomment code below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# quant_model = onnx.load(clip_text_opt_path)\n",
    "# node_name = node_names[28] # /text_encoder/transformer.0/pre_norm_ffn/pre_norm_ffn.4/MatMul\n",
    "# # use_node_name = matmul_nodes_names[8]\n",
    "# use_node_name = node_name\n",
    "\n",
    "# # Find the MatMul node\n",
    "# special_matmul_node = None\n",
    "# for node in quant_model.graph.node:\n",
    "#     if node.op_type == 'MatMul' and node.name == use_node_name:\n",
    "#         special_matmul_node = node\n",
    "#         print(f\"MatMul node found: {special_matmul_node.name}\")\n",
    "#         break\n",
    "\n",
    "# if special_matmul_node is None:\n",
    "#     raise ValueError(f\"MatMul node with name '{use_node_name}' not found in the model.\")\n",
    "\n",
    "# # Get the weight tensor\n",
    "# weight_name = special_matmul_node.input[1]\n",
    "# special_weight_tensor = None\n",
    "# for init in quant_model.graph.initializer:\n",
    "#     if init.name == weight_name:\n",
    "#         special_weight_tensor = init\n",
    "#         break\n",
    "\n",
    "# if special_weight_tensor is None:\n",
    "#     raise ValueError(f\"Weight tensor for MatMul node '{use_node_name}' not found.\")\n",
    "\n",
    "# special_weight_array = onnx.numpy_helper.to_array(special_weight_tensor)\n",
    "\n",
    "# mean = np.mean(special_weight_array)\n",
    "# std = np.std(special_weight_array)\n",
    "# min_val = np.min(special_weight_array)\n",
    "# max_val = np.max(special_weight_array)\n",
    "\n",
    "# print(f\"Statistical Analysis for MatMul node '{use_node_name}':\")\n",
    "# print(f\"Mean: {mean}\")\n",
    "# print(f\"Standard Deviation: {std}\")\n",
    "# print(f\"Minimum: {min_val}\")\n",
    "# print(f\"Maximum: {max_val}\")\n",
    "# print(f\"Dynamic Range: {max_val - min_val}\")\n",
    "\n",
    "# plt.figure(figsize=(10, 6))\n",
    "# plt.hist(special_weight_array.flatten(), bins=50, edgecolor='black')\n",
    "# plt.title(f\"Histogram of Weights for MatMul node '{use_node_name}'\")\n",
    "# plt.xlabel(\"Weight Value\")\n",
    "# plt.ylabel(\"Frequency\")\n",
    "# plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test speed of quantized model (uncomment code below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time_test_size = 1000\n",
    "# mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)\n",
    "# times_unquantized = np.zeros(time_test_size)\n",
    "# times_quantized = np.zeros(time_test_size)\n",
    "\n",
    "# # Time of unquantized model\n",
    "# print(\"Timing unquantized model...\")\n",
    "# for i, caption in tqdm(enumerate(captions[:time_test_size])):\n",
    "#     text_input_test = tokenizer([caption])\n",
    "#     start = time.time()\n",
    "#     _ = model.encode_text(text_input_test)\n",
    "#     end = time.time()\n",
    "#     times_unquantized[i] = end - start\n",
    "\n",
    "# # Time of quantized model\n",
    "# print(\"Timing quantized model...\")\n",
    "# for i, caption in tqdm(enumerate(captions[:time_test_size])):\n",
    "#     text_input_test = tokenizer([caption]).numpy().astype(\"int32\")\n",
    "#     start = time.time()\n",
    "#     _ = mobileclip_text_quant_dyn_ort_sess.run([\"output\"], {\"input\": text_input_test})\n",
    "#     end = time.time()\n",
    "#     times_quantized[i] = end - start\n",
    "\n",
    "# original_mean = times_unquantized.mean()\n",
    "# original_std = times_unquantized.std()\n",
    "# quantized_mean = times_quantized.mean()\n",
    "# quantized_std = times_quantized.std()\n",
    "\n",
    "# print(f\"Original model: {original_mean:.6f} ± {original_std:.6f} seconds\")\n",
    "# print(f\"Quantized model: {quantized_mean:.6f} ± {quantized_std:.6f} seconds\")\n",
    "# print(f\"Speedup: {original_mean / quantized_mean:.2f}x\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the file at `clip_text_quantized_preprocessed_path` (presumably an intermediate\n",
    "# artifact of the text-model quantization above; confirm before relying on that).\n",
    "# Done with pathlib rather than `!rm` so the cell works on any OS and can be re-run\n",
    "# after the file is already gone.\n",
    "from pathlib import Path\n",
    "\n",
    "Path(clip_text_quantized_preprocessed_path).unlink(missing_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantizing image model\n",
    "\n",
    "Eventually got it to roughly 0.996 similarity with the original model, with a size reduction of 54MB (from 143MB to 89MB). That is also not bad, but since the reduction is smaller and the resulting embeddings will be stored permanently, we decided not to use it. Uncomment the code below to restart the investigation if wanted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image_node_names = []\n",
    "# image_matmul_nodes_names = []\n",
    "# image_conv_nodes_names = []\n",
    "# for node in clip_image_opt.graph.node:\n",
    "#     image_node_names.append(node.name)\n",
    "#     if node.op_type == \"MatMul\":\n",
    "#         image_matmul_nodes_names.append(node.name)\n",
    "#     if node.op_type == \"Conv\":\n",
    "#         image_conv_nodes_names.append(node.name)\n",
    "# print(len(image_node_names))\n",
    "# print(len(image_matmul_nodes_names))\n",
    "# print(len(image_conv_nodes_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clip_image_quantized_dynamic_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_int8_opt.onnx\"\n",
    "# exclude = list(set(image_node_names[:100] + image_conv_nodes_names))\n",
    "# quantize_dynamic(clip_image_opt_path, clip_image_quantized_dynamic_path, weight_type=QuantType.QUInt8, nodes_to_exclude=exclude)\n",
    "\n",
    "# mobileclip_image_quant_dyn_ort_sess = ort.InferenceSession(clip_image_quantized_dynamic_path)\n",
    "# image_onnx_quant_dyn_emb = mobileclip_image_quant_dyn_ort_sess.run([\"output\"], {\"input\": image_singapore_onnx})[0][0]\n",
    "# image_onnx_quant_dyn_emb /= norm(image_onnx_quant_dyn_emb)\n",
    "# np.dot(image_onnx_quant_dyn_emb, image_emb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Debug quantizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exclude_amount = 50\n",
    "# exclude_for_sure = image_node_names[:100] + image_node_names[225:260] + image_node_names[280:300] + image_node_names[430:480] + image_node_names[510:560] + image_node_names[650:]\n",
    "\n",
    "# image_test_quant = Image.open(\"../data/singapore.jpg\").convert('RGB')\n",
    "# image_test_quant_onnx = np.array(image_test_quant)\n",
    "\n",
    "# clip_image_opt_sess = ort.InferenceSession(clip_image_opt_path)\n",
    "# onnx_emb_quant_test = clip_image_opt_sess.run(None, {\"input\": image_test_quant_onnx})[0][0]\n",
    "# onnx_emb_quant_test /= norm(onnx_emb_quant_test)\n",
    "\n",
    "\n",
    "# for i in range(550, 600, exclude_amount):\n",
    "#     begin = i\n",
    "#     end = min(i+exclude_amount, len(image_node_names))\n",
    "#     exclude = list(set(exclude_for_sure + image_node_names[begin:end]))\n",
    "    \n",
    "#     clip_image_quantized_dynamic_debug_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_int8dyn_opt_debug.onnx\"\n",
    "#     quantize_dynamic(clip_image_opt_path, clip_image_quantized_dynamic_debug_path, weight_type=QuantType.QUInt8, nodes_to_exclude=exclude)\n",
    "#     mobileclip_image_quant_dyn_ort_sess_debug = ort.InferenceSession(clip_image_quantized_dynamic_debug_path)\n",
    "#     image_onnx_quant_dyn_emb_debug = mobileclip_image_quant_dyn_ort_sess_debug.run([\"output\"], {\"input\": image_test_quant_onnx})[0][0]\n",
    "#     image_onnx_quant_dyn_emb_debug /= norm(image_onnx_quant_dyn_emb_debug)\n",
    "#     sim_debug = np.dot(image_onnx_quant_dyn_emb_debug, onnx_emb_quant_test)\n",
    "#     print(f\"Skipping nodes from {begin} to {end} resulted in a similarity of {sim_debug:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Float16 conversion for Image model\n",
    "\n",
    "https://onnxruntime.ai/docs/performance/model-optimizations/float16.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxconverter_common import convert_float_to_float16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the optimized image model and gather the name of every graph node, in graph order.\n",
    "try_image_model = onnx.load(clip_image_opt_path)\n",
    "check_nodes_names = [graph_node.name for graph_node in try_image_model.graph.node]\n",
    "skip_nodes_names = []\n",
    "\n",
    "# The first 25 node names; these are handed to the float16 conversion as its node_block_list.\n",
    "preprocess_nodes = check_nodes_names[:25]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where the float16 variant of the image model is written.\n",
    "clip_image_fp16_path = f\"onnx_models/mobileclip_s2_image_opset{onnx_opset}_fp16.onnx\"\n",
    "\n",
    "# Convert the loaded model to float16. The nodes named in `preprocess_nodes` are passed\n",
    "# as the block list, and keep_io_types=True is set for the model inputs/outputs.\n",
    "clip_image_fp16 = convert_float_to_float16(\n",
    "    try_image_model,\n",
    "    keep_io_types=True,\n",
    "    disable_shape_infer=True,\n",
    "    node_block_list=preprocess_nodes,\n",
    ")\n",
    "onnx.save(clip_image_fp16, clip_image_fp16_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw RGB pixel array of the test image; it is fed to both sessions unchanged.\n",
    "image_onnx_input = np.array(Image.open(\"../data/singapore.jpg\").convert('RGB'))\n",
    "\n",
    "# One SessionOptions object is shared by both sessions.\n",
    "try_sess_options = ort.SessionOptions()\n",
    "try_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "# Further session settings, currently switched off:\n",
    "# try_sess_options.inter_op_num_threads = 0\n",
    "# try_sess_options.intra_op_num_threads = 0\n",
    "# try_sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL\n",
    "# try_sess_options.enable_profiling = True\n",
    "# try_sess_options.log_severity_level = 0 # Verbose\n",
    "clip_image_fp16_sess = ort.InferenceSession(clip_image_fp16_path, try_sess_options)\n",
    "clip_image_sess = ort.InferenceSession(clip_image_opt_path, try_sess_options)\n",
    "\n",
    "# For each model: take element [0] of the \"output\" result, then scale it in place by its norm.\n",
    "image_onnx_fp16_emb = clip_image_fp16_sess.run([\"output\"], {\"input\": image_onnx_input})[0][0]\n",
    "image_onnx_emb = clip_image_sess.run([\"output\"], {\"input\": image_onnx_input})[0][0]\n",
    "image_onnx_fp16_emb /= norm(image_onnx_fp16_emb)\n",
    "image_onnx_emb /= norm(image_onnx_emb)\n",
    "\n",
    "# Dot product of the two normalized vectors, followed by the first five values of each.\n",
    "print(image_onnx_fp16_emb.dot(image_onnx_emb))\n",
    "print(image_onnx_emb[:5])\n",
    "print(image_onnx_fp16_emb[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test speed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_test_size = 100\n",
    "\n",
    "\n",
    "def _time_inference(session, runs):\n",
    "    \"\"\"Return the total seconds `session` needs for `runs` inferences on `image_onnx_input`.\n",
    "\n",
    "    Timing uses time.perf_counter(), a monotonic high-resolution clock intended for\n",
    "    measuring short durations; time.time() can jump when the system clock is adjusted.\n",
    "    The tqdm progress bar runs inside the timed region for every session alike.\n",
    "    \"\"\"\n",
    "    started = time.perf_counter()\n",
    "    for _ in tqdm(range(runs)):\n",
    "        session.run([\"output\"], {\"input\": image_onnx_input})\n",
    "    return time.perf_counter() - started\n",
    "\n",
    "\n",
    "# The FP16 session is timed first, then the optimized one, each over the same number of runs.\n",
    "time_fp16 = _time_inference(clip_image_fp16_sess, time_test_size)\n",
    "time_opt = _time_inference(clip_image_sess, time_test_size)\n",
    "\n",
    "print(f\"Optimized model: {time_opt:.6f} seconds, so {time_opt / time_test_size:.6f} seconds per inference\")\n",
    "print(f\"FP16 model: {time_fp16:.6f} seconds, so {time_fp16 / time_test_size:.6f} seconds per inference\")\n",
    "print(f\"Speed difference FP16: {time_opt / time_fp16:.2f}x\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ente_clip",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
